package com.jstarcraft.ai.jsat.datatransform;

import java.util.ArrayList;
import java.util.List;

import com.jstarcraft.ai.jsat.DataSet;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.linear.Vec;
import com.jstarcraft.ai.jsat.parameters.Parameterized;
import com.jstarcraft.ai.jsat.parameters.Parameter.ParameterHolder;

/**
 * Performing a transform on the whole data set before training a classifier can
 * add bias to the results. For proper evaluation, the transforms must be
 * learned from the training set and not contain any knowledge from the testing
 * set. A DataTransformProcess aids in this by providing a mechanism to contain
 * several different transforms to learn and then apply, in order. <br>
 * <br>
 * The Parameters of the Data Transform Process are the parameters from the
 * individual transforms that make up the whole process. The name
 * "DataTransformProcess" will not be prefixed to the parameter names.
 * 
 * @author Edward Raff
 */
public class DataTransformProcess implements DataTransform, Parameterized {

    private static final long serialVersionUID = -2844495690944305885L;

    /**
     * The transform templates, in the order they are learned and applied. Each
     * one is cloned before fitting, so these instances are never fit themselves.
     */
    @ParameterHolder(skipSelfNamePrefix = true)
    private List<DataTransform> transformSource;

    /**
     * The fitted clones of {@link #transformSource} produced by the last call to
     * {@link #learnApplyTransforms(DataSet)}. Adjacent
     * {@link RemoveAttributeTransform}s are merged, so this list may be shorter
     * than {@link #transformSource}.
     */
    private List<DataTransform> learnedTransforms;

    /**
     * Creates a new transform process that is empty. Transforms must be added
     * using {@link #addTransform(DataTransform)}.
     */
    public DataTransformProcess() {
        transformSource = new ArrayList<>();
        learnedTransforms = new ArrayList<>();
    }

    /**
     * Creates a new transform process from the listed transforms, which will be
     * learned and applied in order by index.
     * 
     * @param transforms the transforms to apply as the data transform process
     */
    public DataTransformProcess(DataTransform... transforms) {
        this();
        for (DataTransform dt : transforms) {
            this.addTransform(dt);
        }
    }

    /**
     * Adds a transform to the list of transforms. Transforms are learned and
     * applied in the order in which they are added. The given instance acts as a
     * template: it is cloned before being fit, so it is never trained itself.
     * A transform added after learning only takes effect once the transforms
     * are learned again.
     * 
     * @param transform the transform to add
     */
    public void addTransform(DataTransform transform) {
        transformSource.add(transform);
    }

    /**
     * Returns the number of transforms that have been added to this process.
     * This counts the added templates, so it is unaffected by the merging of
     * learned transforms.
     * 
     * @return the number of transforms currently chained in this transform process
     */
    public int getNumberOfTransforms() {
        return transformSource.size();
    }

    /**
     * Consolidates adjacent learned transformation objects when possible.
     * Currently only works with {@link RemoveAttributeTransform}.
     */
    private void consolidateTransforms() {
        for (int i = 0; i < learnedTransforms.size() - 1; i++) {
            DataTransform t1 = learnedTransforms.get(i);
            DataTransform t2 = learnedTransforms.get(i + 1);
            if (!(t1 instanceof RemoveAttributeTransform && t2 instanceof RemoveAttributeTransform)) {
                continue;// They are not both RATs
            }
            RemoveAttributeTransform r1 = (RemoveAttributeTransform) t1;
            RemoveAttributeTransform r2 = (RemoveAttributeTransform) t2;

            // consolidate is expected to make r2 also perform r1's work, which is
            // why r1 can be dropped from the chain afterwards
            r2.consolidate(r1);
            learnedTransforms.remove(i);
            // the merged r2 now sits at index i; revisit i so it can be compared
            // against the transform that follows it
            i--;
        }
    }

    /**
     * Learns the transforms from the given data set without altering it.
     * Transforms are fit and then applied as two separate steps (this class
     * does exactly that with each of its own transforms). Altering {@code data}
     * here would therefore apply the transforms twice whenever a
     * DataTransformProcess is nested inside another one.
     * 
     * @param data the data set to learn a series of transforms from
     */
    @Override
    public void fit(DataSet data) {
        learnTransforms(data);
    }

    /**
     * Learns the transforms for the given data set. The data set will not be
     * altered. Once finished, <tt>this</tt> DataTransformProcess can be applied to
     * the dataSet to get the transformed data set.
     * 
     * @param dataSet the data set to learn a series of transforms from
     */
    public void learnTransforms(DataSet dataSet) {
        learnApplyTransforms(dataSet.shallowClone());
    }

    /**
     * Misspelled alias of {@link #learnTransforms(DataSet)}, kept for backward
     * compatibility. The data set will not be altered.
     * 
     * @param dataSet the data set to learn a series of transforms from
     * @deprecated use {@link #learnTransforms(DataSet)} instead
     */
    @Deprecated
    public void leanTransforms(DataSet dataSet) {
        learnTransforms(dataSet);
    }

    /**
     * Learns the transforms for the given data set. The data set is altered
     * after each transform is learned, so the next transform is learned from the
     * output of the previous ones. Any previously learned transforms are
     * discarded. <br>
     * The results are equivalent to calling {@link #learnTransforms(DataSet)} on
     * the data set and then calling {@link DataSet#applyTransform(DataTransform)}
     * with this DataTransformProcess. <br>
     * An {@link InPlaceTransform} is applied by mutation only once no data point
     * still shares its numeric vector (and, for transforms that mutate nominal
     * values, its categorical array) with the data set as it was passed in. This
     * way the caller's original arrays are never modified in place.
     * 
     * @param dataSet the data set to learn a series of transforms from and alter
     *                into the final transformed form
     */
    public void learnApplyTransforms(DataSet dataSet) {
        learnedTransforms.clear();
        // used to keep track if we can start using in place transforms
        boolean vecSafe = false;
        boolean catSafe = false;
        int iter = 0;

        // copy original references so we can check safety of in-place mutation later
        Vec[] origVecs = new Vec[dataSet.size()];
        int[][] origCats = new int[dataSet.size()][];
        for (int i = 0; i < origVecs.length; i++) {
            DataPoint dp = dataSet.getDataPoint(i);
            origVecs[i] = dp.getNumericalValues();
            origCats[i] = dp.getCategoricalValues();
        }

        for (DataTransform dtf : transformSource) {
            DataTransform transform = dtf.clone();
            transform.fit(dataSet);
            if (transform instanceof InPlaceTransform) {
                InPlaceTransform ipt = (InPlaceTransform) transform;
                // Before any transform has run, every data point still holds its
                // original arrays. So for the first transform the scan is only
                // done when it is needed to decide catSafe.
                if ((iter > 0 && !vecSafe) || (ipt.mutatesNominal() && !catSafe)) {
                    boolean vecClear = true, catClear = true;
                    for (int i = 0; i < origVecs.length && (vecClear || catClear); i++) {
                        DataPoint dp = dataSet.getDataPoint(i);
                        // A single still-shared reference makes mutation unsafe,
                        // so a flag must never flip back to true once it is false.
                        vecClear = vecClear && origVecs[i] != dp.getNumericalValues();
                        catClear = catClear && origCats[i] != dp.getCategoricalValues();
                    }

                    vecSafe = vecClear;
                    catSafe = catClear;
                }

                // Now we know if we can apply the mutations or not
                if (vecSafe && (!ipt.mutatesNominal() || catSafe)) {
                    dataSet.applyTransformMutate(ipt, true);
                } else {
                    // fall back to the non-mutating application
                    dataSet.applyTransform(transform);
                }
            } else {
                dataSet.applyTransform(transform);
            }

            learnedTransforms.add(transform);
            iter++;
        }
        consolidateTransforms();
    }

    /**
     * Applies the learned transforms, in order, to the given data point. An
     * {@link InPlaceTransform} is applied by mutation only after an earlier
     * transform has replaced {@code dp}'s original numeric vector and, for
     * transforms that mutate nominal values, its original categorical array.
     * This way the arrays of the data point passed in are not mutated.
     * 
     * @param dp the data point to transform
     * @return the transformed data point
     */
    @Override
    public DataPoint transform(DataPoint dp) {
        final Vec origNum = dp.getNumericalValues();
        final int[] origCat = dp.getCategoricalValues();
        for (DataTransform dt : learnedTransforms) {
            if (dt instanceof InPlaceTransform) {
                InPlaceTransform it = (InPlaceTransform) dt;
                // check if we can safely mutate in place instead of allocating
                if (origNum != dp.getNumericalValues() && (!it.mutatesNominal() || origCat != dp.getCategoricalValues())) {
                    it.mutableTransform(dp);
                    continue;
                }
            }
            dp = dt.transform(dp);
        }
        return dp;
    }

    /**
     * Creates a copy of this process in which both the transform templates and
     * the learned transforms are cloned.
     * 
     * @return a copy of this transform process
     */
    @Override
    public DataTransformProcess clone() {
        DataTransformProcess clone = new DataTransformProcess();

        for (DataTransform dtf : this.transformSource) {
            clone.transformSource.add(dtf.clone());
        }

        for (DataTransform dt : this.learnedTransforms) {
            clone.learnedTransforms.add(dt.clone());
        }

        return clone;
    }
}
